{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import json\n",
    "from openai import OpenAI\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "# The OpenAI client picks the key up from the OPENAI_API_KEY environment variable.\n",
    "# Ask for it only when it is not already set, so the secret never has to be typed\n",
    "# into the notebook source (the former bare YOUR_API_KEY placeholder raised NameError).\n",
    "if not os.environ.get('OPENAI_API_KEY'):\n",
    "    import getpass\n",
    "    os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API key: ')\n",
    "client = OpenAI()\n",
    "\n",
    "def check_match_with_gpt(question, ground_truth, predicted):\n",
    "    \"\"\"Ask the chat model whether a predicted answer matches the ground truth.\n",
    "\n",
    "    Args:\n",
    "        question: The question text shown to the model.\n",
    "        ground_truth: The reference answer.\n",
    "        predicted: The answer to be judged.\n",
    "\n",
    "    Returns:\n",
    "        '1' (match) or '0' (no match) when the reply is that digit, optionally\n",
    "        surrounded by whitespace or followed by a full stop; otherwise the\n",
    "        stripped raw reply ('' when the reply carries no text), so callers can\n",
    "        still detect an invalid verdict.\n",
    "    \"\"\"\n",
    "    # Construct the grading prompt\n",
    "    prompt = f\"Question: {question}\\nGround Truth Answer: {ground_truth}\\nPredicted Answer: {predicted}\\nDoes the predicted answer match the ground truth? Answer 1 for match and 0 for not match. Use semantic meaning not exact match. Synonyms are also treated as a match, e.g., football and soccer, playground and ground track field, building and rooftop, pond and swimming pool. Do not explain the reason.\\n\"\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o-mini\",\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": [\n",
    "                    {\n",
    "                        \"type\": \"text\",\n",
    "                        \"text\": prompt,\n",
    "                    },\n",
    "                ]\n",
    "            }\n",
    "        ],\n",
    "        # Deterministic grading: the same QA pair should always get the same verdict\n",
    "        temperature=0,\n",
    "        max_tokens=100,\n",
    "    )\n",
    "\n",
    "    # message.content may be None; normalise it to a stripped string\n",
    "    answer = (response.choices[0].message.content or '').strip()\n",
    "\n",
    "    # Consumers of this value compare it against the exact strings '1' and '0',\n",
    "    # so collapse harmless variants such as a trailing full stop\n",
    "    verdict = answer.rstrip('.').strip()\n",
    "    return verdict if verdict in ('0', '1') else answer\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VRSBench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_path = 'Results/outputs_RSBench_new/rsbench_vqa_v2.json'\n",
    "output_path = 'Results/outputs_RSBench_new/rsbench_vqa_v2_gpt.json'\n",
    "\n",
    "# Index of the first QA pair to grade; a non-zero value resumes an interrupted run.\n",
    "# Set to 0 to grade the whole file from scratch.\n",
    "START_INDEX = 37342\n",
    "\n",
    "# The input is JSON Lines: one QA record per line (blank lines are ignored)\n",
    "with open(input_path, 'r') as qa_file:\n",
    "    qa_list = [json.loads(line) for line in qa_file if line.strip()]\n",
    "\n",
    "# Ground truths from this closed set are never sent to GPT\n",
    "closed_set_answers = ['yes', 'no'] + list(map(str, range(100)))\n",
    "\n",
    "# Iterate over the list and check matches\n",
    "results = []\n",
    "# Append when resuming so that previously graded lines are kept; truncate on a fresh run.\n",
    "# Re-running with the same non-zero START_INDEX will append duplicate lines.\n",
    "write_mode = 'a' if START_INDEX > 0 else 'w'\n",
    "with open(output_path, write_mode) as f:\n",
    "    for qa in tqdm(qa_list[START_INDEX:]):\n",
    "        question = qa['question']\n",
    "        ground_truth = qa['ground_truth'].lower()\n",
    "        predicted = qa['answer'].lower()\n",
    "        # NOTE(review): this is a substring test, so a ground truth of '1' also\n",
    "        # matches a prediction of '10' - confirm this is intended\n",
    "        if ground_truth in predicted:\n",
    "            match_result = '1'\n",
    "        elif ground_truth in closed_set_answers:\n",
    "            match_result = '1' if ground_truth == predicted else '0'\n",
    "        elif 'correct' not in qa or qa['correct'] not in ['1', '0']:\n",
    "            # No usable cached verdict on the record: ask the model\n",
    "            match_result = check_match_with_gpt(question, ground_truth, predicted)\n",
    "        else:\n",
    "            match_result = qa['correct']\n",
    "\n",
    "        result = {\n",
    "            'question_id': qa['question_id'],\n",
    "            'image_id': qa['image_id'],\n",
    "            \"type\": qa['type'],\n",
    "            \"question\": question,\n",
    "            \"ground_truth\": ground_truth,\n",
    "            \"predicted\": predicted,\n",
    "            \"correct\": match_result,\n",
    "        }\n",
    "        results.append(result)\n",
    "\n",
    "        # Flush after every record so progress survives an interruption\n",
    "        f.write(json.dumps(result)+'\\n')\n",
    "        f.flush()\n",
    "\n",
    "# Preview the first few graded records\n",
    "for result in results[:6]:\n",
    "    print(result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the graded records (JSON Lines, one record per line; blank lines are ignored)\n",
    "with open('Results/outputs_RSBench_new/rsbench_vqa_v2_gpt.json', 'r') as f:\n",
    "    results = [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "# Only a verdict that is exactly '1' counts as correct; every other value is scored as wrong\n",
    "correct = sum(1 for result in results if result['correct'] == '1')\n",
    "# Guard against an empty results file\n",
    "accuracy = correct/len(results) if results else 0.0\n",
    "print(f\"Correct: {correct}/{len(results)}:\", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metrics per question type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inflect\n",
    "\n",
    "# Create an engine instance\n",
    "convert = inflect.engine()\n",
    "\n",
    "data_path = 'Results/outputs_MGM-7B_RSBench-new/rsbench_vqa_v2_gpt.json'\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "all_types = ['object category', 'object existence', 'object quantity', 'object color', 'object shape', 'object size', 'object position', 'object direction', 'image', 'scene type', 'reasoning', 'rural or urban']\n",
    "\n",
    "print('number of question types:', len(all_types))\n",
    "\n",
    "all_numbers = [convert.number_to_words(x) for x in range(100)]\n",
    "\n",
    "# Question types that are counted under the 'scene type' bucket\n",
    "merged_into_scene_type = ('image', 'rural or urban')\n",
    "\n",
    "# create a dict with types as key and value to zero\n",
    "correct_per_type = {k: 0 for k in all_types}\n",
    "total_per_type = {k: 0 for k in all_types}\n",
    "invalid_type = 0\n",
    "skip_qas = 0\n",
    "with open(data_path, 'r') as file:\n",
    "    for line in file:\n",
    "        # Skip blank lines instead of failing inside json.loads\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        # Convert JSON string to Python dictionary\n",
    "        item = json.loads(line)\n",
    "\n",
    "        q_type = item['type'].lower()\n",
    "        if q_type in merged_into_scene_type:\n",
    "            q_type = 'scene type'\n",
    "\n",
    "        known_type = q_type in all_types\n",
    "        if known_type:\n",
    "            total_per_type[q_type] += 1\n",
    "        else:\n",
    "            print('unknown type:', q_type)\n",
    "            invalid_type += 1\n",
    "\n",
    "        # Records of unknown type still count towards the overall accuracy\n",
    "        if item['correct'] == '1':\n",
    "            correct += 1\n",
    "            if known_type:\n",
    "                correct_per_type[q_type] += 1\n",
    "\n",
    "        total += 1\n",
    "\n",
    "print('number of questions:', total, 'invalid_type:', invalid_type, 'valid', sum(total_per_type.values()))\n",
    "# Guard against an empty results file\n",
    "print('Overall acc:', correct/total * 100 if total else 0.0)\n",
    "# divide by the number of questions of that type\n",
    "print('##############')\n",
    "acc_list = []\n",
    "for k in all_types:\n",
    "    # Types without any question (including the two merged into 'scene type') are left out\n",
    "    if total_per_type[k] == 0:\n",
    "        continue\n",
    "    acc = correct_per_type[k]/total_per_type[k] * 100\n",
    "    print(f'{k} accuracy: {acc}, out of {total_per_type[k]}')\n",
    "    acc_list.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emit one LaTeX table row: the per-type accuracies followed by their mean, each as xx.x\n",
    "row_values = [*acc_list, np.mean(acc_list)]\n",
    "print(' & '.join(f'{value:.1f}' for value in row_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skip_qas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
